package orm

import (
	"errors"
	glog "gitee.com/guolianyu/kit/skit/orm/logger"
	"gitee.com/guolianyu/kit/skit/orm/mysql"
	"gitee.com/guolianyu/kit/skit/orm/postgres"
	"gitee.com/guolianyu/kit/skit/third_party/config"
	klog "github.com/go-kratos/kratos/v2/log"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/plugin/dbresolver"
	"gorm.io/plugin/opentelemetry/tracing"
	"gorm.io/plugin/prometheus"
)

// Sentinel errors reported while setting up a database client.
var (
	// ErrUnsupportedType is returned by NewGormClient when the configured
	// driver is neither MySQL nor PostgresSQL.
	ErrUnsupportedType         = errors.New("unsupported database type")
	// ErrUnsupportedResolverType is returned by registerResolver when a
	// resolver's type is neither Source nor Replica.
	ErrUnsupportedResolverType = errors.New("unsupported resolver type")
)

// Driver names a database backend. NewGormClient converts the configured
// driver string to a Driver and dispatches on it.
type Driver string

// String returns the driver name as a plain string.
func (drv Driver) String() string {
	return string(drv)
}

// Driver values recognised by NewGormClient.
const (
	MySQL       = Driver("mysql")    // opened through the mysql sub-package
	PostgresSQL = Driver("postgres") // opened through the postgres sub-package
)

// orm collects the settings NewGormClient works from. Option functions
// receive a pointer to it and may adjust it before the connection is opened.
type orm struct {
	// cfg is the database configuration passed to NewGormClient.
	cfg     *config.Database
	// logger is the gorm logger handed to the driver constructor.
	logger  logger.Interface
	// plugins are registered on the opened connection with db.Use.
	plugins []gorm.Plugin
	// metrics is initialised empty by NewGormClient and not read there;
	// presumably filled by an Option defined elsewhere — TODO confirm.
	metrics []prometheus.MetricsCollector
}

// Option adjusts the orm settings used by NewGormClient. Options are applied
// in the order given, after the defaults have been filled in.
type Option func(*orm)

// NewGormClient opens a gorm connection for the driver named in cfg, registers
// read/write splitting when cfg lists resolvers, and installs the tracing and
// Prometheus plugins after any plugins contributed by opts.
//
// A nil cfg means "no database configured": the result is a nil *gorm.DB, a
// no-op cleanup and a nil error.
//
// The returned cleanup is never nil, so it is always safe to call or defer,
// even when err is non-nil. On success it closes the underlying connection
// pool. On failure db is nil, any connection opened so far has already been
// closed, and cleanup does nothing. ErrUnsupportedType is returned when
// cfg.Driver is neither MySQL nor PostgresSQL.
func NewGormClient(cfg *config.Database, opts ...Option) (db *gorm.DB, cleanup func(), err error) {
	// noop lets every return path hand back a callable cleanup.
	noop := func() {}

	if cfg == nil {
		return nil, noop, nil
	}

	o := orm{
		cfg:     cfg,
		logger:  glog.New(klog.GetLogger()),
		plugins: make([]gorm.Plugin, 0),
		metrics: make([]prometheus.MetricsCollector, 0),
	}

	c := o.cfg

	for _, opt := range opts {
		if opt == nil {
			// A nil option has nothing to apply; skipping avoids a nil call panic.
			continue
		}
		opt(&o)
	}

	// Append the monitoring plugins: OpenTelemetry tracing and Prometheus metrics.
	// NOTE(review): o.metrics is not forwarded to the Prometheus plugin here;
	// confirm whether option-supplied collectors are meant to be included.
	o.plugins = append(o.plugins,
		tracing.NewPlugin(tracing.WithDBName(c.Name)),
		prometheus.New(prometheus.Config{
			DBName:          c.Name, // the database name is used as the metric label
			RefreshInterval: 15,     // metrics refresh interval in seconds (15 is the plugin default)
			MetricsCollector: []prometheus.MetricsCollector{
				&RequestMetrics{
					Database: c.Name,
				},
			},
		}),
	)

	var conn *gorm.DB
	switch Driver(c.Driver) {
	case MySQL:
		conn, err = mysql.New(c, o.logger)
	case PostgresSQL:
		conn, err = postgres.New(c, o.logger)
	default:
		klog.Errorf("[DATABASE]init database fail, err: %v", ErrUnsupportedType)
		return nil, noop, ErrUnsupportedType
	}
	if err != nil {
		return nil, noop, err
	}

	// closePool releases the connection pool behind conn. When DB() fails the
	// returned handle is nil, so Close must not be attempted on it.
	closePool := func() {
		sqlDB, dbErr := conn.DB()
		if dbErr != nil {
			klog.Error(dbErr)
			return
		}
		if closeErr := sqlDB.Close(); closeErr != nil {
			klog.Error(closeErr)
		}
	}

	// Read/write splitting.
	if len(c.Resolvers) > 0 {
		if err = registerResolver(conn, Driver(c.Driver), c.Resolvers); err != nil {
			klog.Errorf("[DATABASE]init database Resolvers fail, err: %v", err)
			closePool() // do not leak the pool that was just opened
			return nil, noop, err
		}
	}

	// Install the plugins that monitor the connection.
	for _, plugin := range o.plugins {
		if err = conn.Use(plugin); err != nil {
			closePool() // do not leak the pool that was just opened
			return nil, noop, err
		}
	}

	klog.Info("[DATABASE]init database resources seccuss")
	cleanup = func() {
		klog.Info("[DATABASE]closing the database resources")
		closePool()
	}

	return conn, cleanup, nil
}

// registerResolver installs the dbresolver plugin on db for read/write
// splitting.
//
// Every entry in resolvers is converted to a dialector with BuildDialector and
// collected as a source or a replica according to its Type; any other type
// yields ErrUnsupportedResolverType. An empty resolvers list is a no-op that
// returns nil. The plugin is registered with dbresolver.RandomPolicy.
func registerResolver(db *gorm.DB, driver Driver, resolvers []*config.Database_Resolver) error {
	total := len(resolvers)
	if total == 0 {
		return nil
	}

	srcDialectors := make([]gorm.Dialector, 0, total)
	replicaDialectors := make([]gorm.Dialector, 0, total)

	for _, entry := range resolvers {
		dialector, buildErr := BuildDialector(driver, entry)
		if buildErr != nil {
			return buildErr
		}

		switch kind := ResolverType(entry.Type); kind {
		case Source:
			srcDialectors = append(srcDialectors, dialector)
		case Replica:
			replicaDialectors = append(replicaDialectors, dialector)
		default:
			return ErrUnsupportedResolverType
		}
	}

	resolverCfg := dbresolver.Config{
		Sources:  srcDialectors,
		Replicas: replicaDialectors,
		Policy:   dbresolver.RandomPolicy{},
	}
	return db.Use(dbresolver.Register(resolverCfg))
}
